FA num splits option #2357
base: main
Conversation
Greptile Overview
Greptile Summary
This PR exposes the num_splits parameter for FlashAttention v2 and v3 backends, allowing users to control memory optimization during attention computation.
Key Changes:
- Added optional num_splits parameter to DotProductAttention.forward() method
- Passes num_splits to both FlashAttention v2 and v3 backend implementations when provided
- Parameter is conditionally added to kwargs only when not None
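The conditional-kwargs pattern described above can be sketched as a small standalone helper (the helper name is illustrative, not from the PR; the real code builds the dict inline in the backend):

```python
def build_optional_fa_kwargs(num_splits=None, window_size=None):
    """Collect optional flash-attn keyword arguments, adding each one
    only when the caller actually sets it, so a flash-attn build that
    lacks a given parameter never receives an unexpected keyword."""
    kwargs = {}
    if num_splits is not None:
        kwargs["num_splits"] = num_splits
    if window_size is not None:
        kwargs["window_size"] = window_size
    return kwargs
```

The resulting dict would then be splatted into the backend call, e.g. `flash_attn_func(..., **build_optional_fa_kwargs(num_splits=2))`.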
Areas for Improvement:
- Missing parameter documentation in the docstring
- No version compatibility check for flash-attn (unlike other optional parameters like window_size and deterministic)
- No tests demonstrating the new functionality
Confidence Score: 4/5
- This PR is safe to merge with minor documentation improvements recommended
- The implementation correctly follows the existing pattern for optional parameters in FlashAttention backends. The changes are minimal and well-scoped. However, the score is not 5 due to: (1) missing parameter documentation, (2) lack of version compatibility checks that other optional parameters have, and (3) no accompanying tests. These are quality-of-life improvements rather than critical issues.
- No files require special attention - the implementation is straightforward and follows existing patterns
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/attention/dot_product_attention/backends.py | 4/5 | Added num_splits parameter to FlashAttention forward method and passes it to both FA v2 and FA v3 backends when provided |
| transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py | 4/5 | Added num_splits parameter to DotProductAttention forward signature and forwards it to FlashAttention backend |
Sequence Diagram
sequenceDiagram
participant User
participant DotProductAttention
participant FlashAttention
participant FlashAttn_v2
participant FlashAttn_v3
User->>DotProductAttention: forward(query, key, value, num_splits=...)
Note over DotProductAttention: New optional parameter
DotProductAttention->>FlashAttention: forward(..., num_splits=num_splits)
Note over FlashAttention: Check backend version
alt Flash Attention v2
FlashAttention->>FlashAttention: if num_splits is not None
FlashAttention->>FlashAttention: fa_optional_forward_kwargs["num_splits"] = num_splits
FlashAttention->>FlashAttn_v2: flash_attn_func(..., **fa_optional_forward_kwargs)
FlashAttn_v2-->>FlashAttention: output
else Flash Attention v3
FlashAttention->>FlashAttention: if num_splits is not None
FlashAttention->>FlashAttention: fa_3_optional_forward_kwargs["num_splits"] = num_splits
FlashAttention->>FlashAttn_v3: flash_attn_func(..., **fa_3_optional_forward_kwargs)
FlashAttn_v3-->>FlashAttention: output
end
FlashAttention-->>DotProductAttention: output
DotProductAttention-->>User: output
2 files reviewed, 3 comments
inference_params: Optional[InferenceParams] = None,
pad_between_seqs: Optional[bool] = None,
fp8_output: Optional[bool] = False,
num_splits: Optional[int] = None,
style: Missing documentation for the num_splits parameter in the docstring. Add a parameter description explaining what this controls and when it should be used (e.g., "Controls the number of splits for FlashAttention computation. Used for memory optimization.")
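A sketch of the kind of docstring entry the reviewer is asking for, in the NumPy style used elsewhere in the codebase (the wording and the stub signature are illustrative, not taken from the PR):

```python
def forward(query, key, value, num_splits=None):
    """Attention forward pass (illustrative stub).

    Parameters
    ----------
    num_splits : Optional[int], default = None
        Number of splits used by the FlashAttention kernel when
        partitioning the attention computation. Tuning this can trade
        parallelism against memory use; when None, the backend chooses
        the split count automatically.
    """
    return None
```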
if num_splits is not None:
    # Forward optional split control to flash-attn if available
    fa_optional_forward_kwargs["num_splits"] = num_splits
style: Verify that the flash-attn version supports num_splits parameter. Unlike other optional parameters (e.g., window_size has fa_utils.v2_3_plus check, deterministic has fa_utils.v2_4_1_plus check), this parameter is added without a version guard.
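The version guard the reviewer suggests could follow the same shape as the existing window_size (v2_3_plus) and deterministic (v2_4_1_plus) checks. A minimal stdlib-only sketch; the minimum-version threshold and the helper names here are assumptions, not values from flash-attn's changelog:

```python
import warnings


def _parse_version(ver):
    """Parse an 'X.Y.Z' version string into a comparable tuple of ints."""
    return tuple(int(p) for p in ver.split(".")[:3])


def add_num_splits_if_supported(kwargs, num_splits, flash_attn_version,
                                min_version="2.99.0"):
    """Mirror the window_size/deterministic pattern: forward num_splits
    only when the installed flash-attn is new enough; otherwise warn and
    drop it. min_version is a placeholder, not a real flash-attn release."""
    if num_splits is not None:
        if _parse_version(flash_attn_version) >= _parse_version(min_version):
            kwargs["num_splits"] = num_splits
        else:
            warnings.warn(
                f"flash-attn {flash_attn_version} does not support "
                "num_splits; the argument will be ignored."
            )
    return kwargs
```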
if num_splits is not None:
    # Forward optional split control to flash-attn v3 if supported
    fa_3_optional_forward_kwargs["num_splits"] = num_splits
style: Same as FA v2: verify flash-attn v3 supports num_splits to avoid potential TypeError
Greptile Overview
Greptile Summary
This PR adds support for controlling the num_splits parameter in the FlashAttention-3 path.
Key Changes:
Issues Found:
Confidence Score: 4/5
Important Files Changed
File Analysis
Sequence Diagram
sequenceDiagram
participant User
participant DotProductAttention
participant FlashAttention
participant FA3/FA2
User->>DotProductAttention: forward(query, key, value, num_splits=2)
DotProductAttention->>DotProductAttention: Validate num_splits requires FA3
alt FA3 not selected and num_splits set
DotProductAttention-->>User: ValueError (FA3 required)
end
DotProductAttention->>FlashAttention: forward(..., num_splits=2)
FlashAttention->>FlashAttention: Check use_flash_attn_3 flag
alt num_splits set and not FA3
FlashAttention-->>DotProductAttention: ValueError (FA3 required)
end
alt FA3 enabled
FlashAttention->>FA3: flash_attn_*_v3(..., num_splits=2)
FA3-->>FlashAttention: output
else FA2 path (num_splits must be None)
FlashAttention->>FA2: flash_attn_func(...)
FA2-->>FlashAttention: output
end
FlashAttention-->>DotProductAttention: output
DotProductAttention-->>User: output
3 files reviewed, 1 comment
if num_splits is not None:
    # Forward optional split control to flash-attn if available
    fa_optional_forward_kwargs["num_splits"] = num_splits
logic: Unlike other optional parameters (window_size has fa_utils.v2_3_plus, deterministic has fa_utils.v2_4_1_plus), num_splits is added to FA2 kwargs without version checking. If flash-attn v2 doesn't support this parameter, this will cause a TypeError at runtime.
if num_splits is not None:
    # Forward optional split control to flash-attn if available
    fa_optional_forward_kwargs["num_splits"] = num_splits

Suggested change:

if num_splits is not None:
    # Only add num_splits if flash-attn supports it (check version if needed)
    # TODO: Add version check once minimum flash-attn version with num_splits is determined
    fa_optional_forward_kwargs["num_splits"] = num_splits
@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "thd_thd_thd"])
def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
@pytest.mark.parametrize("num_splits", [None, 2])
def test_dpa_max_logit(dtype, model_configs, model, qkv_layout, num_splits):
Can you make the num_splits a separate test, instead of piggybacking on the max_logit test :) You can still call test_dot_product_attention in it the same way you do here. Thanks!
done I think
3 files reviewed, no comments
Could you please follow the instructions here to fix the DCO? Thanks!
/te-ci pytorch L0
57 files reviewed, 1 comment
if num_splits is not None and not use_flash_attn_3:
    if not fa_utils.v3_is_installed:
        raise ValueError(
            "num_splits is only supported with FlashAttention-3, which is not installed. "
        )
    raise ValueError(
        "num_splits is only supported with FlashAttention-3. "
        "Please adjust configuration to enable FA3 for these inputs."
    )
style: The validation logic here correctly prevents num_splits from being used with non-FA3 backends. However, there's a potential issue with code maintainability: later in the code (line 939), num_splits is added to fa_optional_forward_kwargs without a version check, unlike other optional parameters. While the current validation prevents reaching that code with num_splits set and FA2, this creates fragile coupling between distant code sections. Consider either: (1) adding a version guard at line 939 similar to window_size and deterministic, or (2) adding an assertion that num_splits is None in the FA2 branch.
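The reviewer's option (2), a defensive assertion in the FA2 branch, can be sketched as follows. The dispatch function and its return shape are illustrative, not the PR's actual control flow:

```python
def select_backend_kwargs(use_flash_attn_3, num_splits=None):
    """Validate num_splits early, then keep a defensive assertion in the
    FA2 branch so the two distant code sections cannot silently drift
    apart if the early validation is ever changed or removed."""
    if num_splits is not None and not use_flash_attn_3:
        raise ValueError("num_splits is only supported with FlashAttention-3.")
    if use_flash_attn_3:
        kwargs = {}
        if num_splits is not None:
            kwargs["num_splits"] = num_splits
        return "fa3", kwargs
    # FA2 branch: the early check above guarantees num_splits is None here.
    assert num_splits is None, "num_splits must be None on the FA2 path"
    return "fa2", {}
```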
I think this is done?
3 files reviewed, 1 comment
if num_splits is not None:
    # Forward optional split control to flash-attn if available
    fa_optional_forward_kwargs["num_splits"] = num_splits
style: This is unreachable dead code. The validation at line 848 already ensures num_splits is None when use_flash_attn_3 is False, so this condition can never be true inside the if not use_flash_attn_3: block. Consider removing this check for code clarity.
Description
I want to be able to control num_splits in FA3. This PR exposes that argument for non-context-parallel cases.
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: